library(tidyverse)
# Create a ggplot from `...` (forwarded unchanged to ggplot()) and apply the
# bigstatsr plotting theme to it. `coeff` is passed through to that theme.
# MY_THEME is reached with `:::`, i.e. it is a bigstatsr internal.
myggplot <- function(..., coeff = 1) {
  base_plot <- ggplot(...)
  bigstatsr:::MY_THEME(base_plot, coeff = coeff)
}
# Boxplots of one result column per method, coloured and filled by the
# distribution of effects and faceted by model (rows) x causal set (columns).
#
# results: data frame with columns `method`, `par.dist`, `par.model`,
#          `par.causal` and the column named by `y`.
# y:       name (string) of the column drawn on the y axis.
# ylab:    y-axis label; defaults to the column name `y`.
plot_results <- function(results, y, ylab = y) {
  legend_title <- "Distribution\nof effects"
  box_mapping <- aes_string("method", y, color = "par.dist", fill = "par.dist")
  myggplot(results) +
    geom_boxplot(box_mapping, alpha = 0.3) +
    facet_grid(par.model ~ par.causal) +
    theme(axis.text.x = element_text(angle = 45, hjust = 1),
          strip.text.x = element_text(size = rel(2)),
          strip.text.y = element_text(size = rel(2))) +
    labs(x = "Method", y = ylab, fill = legend_title, color = legend_title)
}
# Bootstrap standard error of the statistic `f` computed on `x`.
#
# x: vector of observations to resample (with replacement, same size as x).
# n: number of bootstrap replicates.
# f: statistic applied to each resample (default: mean).
# Returns the standard deviation of the `n` bootstrap values of `f`.
boot <- function(x, n = 1e5, f = mean) {
  # Resample through indices: sample(x) would draw from 1:x whenever x is a
  # single number >= 1. For length(x) > 1 this makes the same RNG draws as
  # sample(x, replace = TRUE).
  sd(replicate(n, f(x[sample.int(length(x), replace = TRUE)])))
}
# Positions of the 110 / 220 highest scores; used to compute the proportion
# of cases among the top-ranked individuals.
# NOTE(review): presumably 10% / 20% of the test-set size -- confirm.
top10 <- 1:110
top20 <- 1:220
# Put all results in a single tibble
results <- list.files("results1", full.names = TRUE) %>%
  map_dfr(readRDS) %>%
  as_tibble() %>%
  mutate(
    par.causal = factor(
      map_chr(par.causal, function(p) paste(p[1], p[2], sep = " in ")),
      levels = c("30 in HLA", paste(3 * 10^(1:3), "in all"))
    ),
    AUC = map_dbl(eval, function(ev) bigstatsr::AUC(ev[, 1], ev[, 2])),
    percCases10 = map_dbl(eval, function(ev) {
      ranked <- order(ev[, 1], decreasing = TRUE)
      mean(ev[ranked[top10], 2])
    }),
    percCases20 = map_dbl(eval, function(ev) {
      ranked <- order(ev[, 1], decreasing = TRUE)
      mean(ev[ranked[top20], 2])
    })
  )
# Mean of each metric per parameter combination, T-Trees vs logit-simple
results %>%
  filter(method %in% c("T-Trees", "logit-simple")) %>%
  group_by_at(vars(starts_with("par"), method)) %>%
  summarise_at(vars(timing, nb.preds, AUC, percCases10, percCases20), mean) %>%
  print(n = Inf)
## # A tibble: 32 x 10
## # Groups: par.causal, par.dist, par.h2, par.model [?]
## par.causal par.dist par.h2 par.model method timing nb.preds AUC percCases10 percCases20
## <fctr> <chr> <dbl> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 30 in HLA gaussian 0.8 fancy logit-simple 509.9386 793.6 0.9001502 0.9272727 0.8245455
## 2 30 in HLA gaussian 0.8 fancy T-Trees 2314.0056 3300.8 0.8177493 0.8000000 0.6818182
## 3 30 in HLA gaussian 0.8 simple logit-simple 489.3918 948.0 0.9424040 0.9800000 0.8809091
## 4 30 in HLA gaussian 0.8 simple T-Trees 2410.5636 5768.6 0.8347770 0.8418182 0.6990909
## 5 30 in HLA laplace 0.8 fancy logit-simple 500.2442 966.2 0.9008663 0.9054545 0.8272727
## 6 30 in HLA laplace 0.8 fancy T-Trees 2805.8358 4542.8 0.8607568 0.8127273 0.7427273
## 7 30 in HLA laplace 0.8 simple logit-simple 500.8506 722.0 0.9302174 0.9636364 0.8681818
## 8 30 in HLA laplace 0.8 simple T-Trees 1644.6880 1995.0 0.8481380 0.8509091 0.7300000
## 9 30 in all gaussian 0.8 fancy logit-simple 518.2830 655.6 0.8775273 0.8745455 0.7709091
## 10 30 in all gaussian 0.8 fancy T-Trees 3483.2610 7330.4 0.8502545 0.8563636 0.7272727
## 11 30 in all gaussian 0.8 simple logit-simple 505.6132 649.0 0.9323514 0.9672727 0.8918182
## 12 30 in all gaussian 0.8 simple T-Trees 2922.9090 5472.8 0.8310516 0.8490909 0.7327273
## 13 30 in all laplace 0.8 fancy logit-simple 510.6026 999.2 0.8932313 0.9036364 0.8118182
## 14 30 in all laplace 0.8 fancy T-Trees 4265.4018 7901.2 0.8457449 0.8200000 0.7318182
## 15 30 in all laplace 0.8 simple logit-simple 496.2482 694.8 0.9249049 0.9636364 0.8700000
## 16 30 in all laplace 0.8 simple T-Trees 2274.9888 4626.0 0.8322151 0.8109091 0.7118182
## 17 300 in all gaussian 0.8 fancy logit-simple 572.4542 1677.0 0.7796750 0.7254545 0.6481818
## 18 300 in all gaussian 0.8 fancy T-Trees 12507.7136 32797.4 0.6265848 0.5454545 0.4581818
## 19 300 in all gaussian 0.8 simple logit-simple 554.6804 2396.6 0.8527230 0.8600000 0.7381818
## 20 300 in all gaussian 0.8 simple T-Trees 10242.1552 29357.4 0.6003698 0.4181818 0.4009091
## 21 300 in all laplace 0.8 fancy logit-simple 552.4216 1785.4 0.8259872 0.7981818 0.7081818
## 22 300 in all laplace 0.8 fancy T-Trees 7936.3082 18462.2 0.7020579 0.5854545 0.5309091
## 23 300 in all laplace 0.8 simple logit-simple 541.8522 2170.2 0.8604749 0.8545455 0.7381818
## 24 300 in all laplace 0.8 simple T-Trees 6510.7720 18024.0 0.6535618 0.5381818 0.4709091
## 25 3000 in all gaussian 0.8 fancy logit-simple 604.9040 5401.0 0.5738792 0.3945455 0.3727273
## 26 3000 in all gaussian 0.8 fancy T-Trees 16291.6926 40748.6 0.5067176 0.2981818 0.2954545
## 27 3000 in all gaussian 0.8 simple logit-simple 594.6762 6442.2 0.5906757 0.4309091 0.4100000
## 28 3000 in all gaussian 0.8 simple T-Trees 13911.6304 40190.8 0.5118337 0.3127273 0.3172727
## 29 3000 in all laplace 0.8 fancy logit-simple 593.8938 3416.2 0.6326039 0.5018182 0.4536364
## 30 3000 in all laplace 0.8 fancy T-Trees 17154.7194 40435.0 0.5151644 0.3036364 0.3136364
## 31 3000 in all laplace 0.8 simple logit-simple 585.4728 5191.2 0.6329355 0.4927273 0.4754545
## 32 3000 in all laplace 0.8 simple T-Trees 14272.6158 40443.8 0.5019465 0.3109091 0.3027273
# T-Trees vs penalized logistic regression ("logit-simple"): four panels
# (timing, model size, AUC, % of cases in top 10%) sharing one legend.
ttrees_vs_logit <- results %>%
  filter(method %in% c("T-Trees", "logit-simple"))
p_list <- list(
  plot_results(ttrees_vs_logit, y = "timing", ylab = "Timing (in seconds)") +
    scale_y_continuous(breaks = 0:10 * 2000, minor_breaks = NULL),
  plot_results(ttrees_vs_logit, y = "nb.preds",
               ylab = "Number of predictors (log-scale)") +
    scale_y_log10(breaks = c(10^(0:7), 3 * 10^(0:7)), minor_breaks = NULL,
                  labels = scales::comma_format()),
  plot_results(ttrees_vs_logit, y = "AUC") +
    scale_y_continuous(breaks = 0:10 / 10, minor_breaks = c(0:9 + 0.5) / 10),
  plot_results(ttrees_vs_logit, y = "percCases10",
               ylab = "Percentage of cases in top 10%") +
    scale_y_continuous(breaks = 0:10 / 10, minor_breaks = c(0:9 + 0.5) / 10)
)
# Panels without legends in a 2x2 grid, then the legend of the first panel
# added as a narrow column on the right.
p_list %>%
  lapply(function(plt) plt + theme(legend.position = "none")) %>%
  cowplot::plot_grid(plotlist = ., ncol = 2, align = "hv", scale = 0.9,
                     labels = LETTERS[1:4], label_size = 25) %>%
  cowplot::plot_grid(cowplot::get_legend(p_list[[1]]),
                     rel_widths = c(1, 0.15))
Results of T-Trees vs penalized logistic regression. A. Timing (in seconds). B. Number of predictors of the model. C. AUC. D. Percentage of cases in the 10% largest scores.
# Save the last plot (the T-Trees vs logit grid above).
ggsave(filename = "figures/ttrees.pdf", width = 1580, height = 1070,
       scale = 1/90)
# Same ranking cut-offs as for the first set of results.
top10 <- 1:110
top20 <- 1:220
# Put all results in a single tibble
results2 <- list.files("results2", full.names = TRUE) %>%
  map_dfr(readRDS) %>%
  as_tibble() %>%
  mutate(
    par.causal = factor(
      map_chr(par.causal, function(p) paste(p[1], p[2], sep = " in ")),
      levels = c("30 in HLA", paste(3 * 10^(1:3), "in all"))
    ),
    AUC = map_dbl(eval, function(ev) bigstatsr::AUC(ev[, 1], ev[, 2])),
    percCases10 = map_dbl(eval, function(ev) {
      ranked <- order(ev[, 1], decreasing = TRUE)
      mean(ev[ranked[top10], 2])
    }),
    percCases20 = map_dbl(eval, function(ev) {
      ranked <- order(ev[, 1], decreasing = TRUE)
      mean(ev[ranked[top20], 2])
    })
  )
H2 <- 0.8
# Proportion of cases among the top-ranked scores against AUC, for the
# simulations with h2 == H2: top 10% (A, C) and top 20% (B, D), coloured by
# distribution of effects (A, B) or by model (C, D).
local({
  sims_h2 <- filter(results2, par.h2 == H2)
  # One scatter panel with a linear fit; legend inside the plotting area.
  scatter_vs_auc <- function(mapping, ylab, color_lab) {
    myggplot(sims_h2, mapping) +
      geom_point() +
      geom_smooth(method = "lm") +
      theme(legend.position = c(0.8, 0.2)) +
      labs(y = ylab, color = color_lab)
  }
  dist_lab <- "Distribution\nof effects"
  cowplot::plot_grid(
    scatter_vs_auc(aes(AUC, percCases10, color = par.dist),
                   "Percentage of cases in top 10%", dist_lab),
    scatter_vs_auc(aes(AUC, percCases20, color = par.dist),
                   "Percentage of cases in top 20%", dist_lab),
    scatter_vs_auc(aes(AUC, percCases10, color = par.model),
                   "Percentage of cases in top 10%", "Model") +
      scale_colour_brewer(palette = "Set1"),
    scatter_vs_auc(aes(AUC, percCases20, color = par.model),
                   "Percentage of cases in top 20%", "Model") +
      scale_colour_brewer(palette = "Set1"),
    labels = LETTERS[1:4], align = "hv", label_size = 25, scale = 0.95
  )
})
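Percentage of cases in the top 10% (A, C) and top 20% (B, D) of scores as a function of AUC, for h2=0.8; points are coloured by distribution of effects (A, B) or by model (C, D).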
# Save the last plot (the AUC-correlation grid above).
ggsave(filename = "figures/AUC-corr.pdf", width = 1300, height = 950,
       scale = 1/90)
# AUC of all methods for h2 = 0.8 (all chromosomes); dotted blue line at 0.94.
plot_results(filter(results2, par.h2 == 0.8), y = "AUC") +
  geom_hline(yintercept = 0.94, color = "blue", linetype = 3) +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = c(0:9 + 0.5) / 10)
All AUC results for h2=0.8 and all chromosomes
# Relative AUC: each method's AUC divided by the AUC of "logit-simple" in the
# same simulation (same parameters and num.simu), for h2 = 0.8 and all
# chromosomes.
results2 %>%
  filter(par.h2 == 0.8) %>%
  group_by_at(c(vars(starts_with("par")), "num.simu")) %>%
  # match() takes the first "logit-simple" row of the group and gives NA when
  # the group has none, instead of failing on a length-0 or length-2 divisor.
  mutate(AUC_rel = AUC / AUC[match("logit-simple", method)]) %>%
  ungroup() %>%
  plot_results(y = "AUC_rel", ylab = "Relative AUC / 'logit-simple'") +
  # Relative AUCs can exceed 1, so keep labelled breaks above 1 as well.
  scale_y_continuous(breaks = 0:20 / 10, minor_breaks = c(0:19 + 0.5) / 10)
All relative AUC results for h2=0.8 and all chromosomes
# AUC of all methods for h2 = 0.5 (all chromosomes).
plot_results(filter(results2, par.h2 == 0.5), y = "AUC") +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = c(0:9 + 0.5) / 10)
All AUC results for h2=0.5 and all chromosomes.
# Represent h2=0.8 as a function of h2=0.5: mean AUC per parameter
# combination and method, with one column per h2 value after spreading.
results2 %>%
  select(starts_with("par"), method, AUC) %>%
  group_by(par.causal, par.dist, par.model, method, par.h2) %>%
  summarise(AUC = mean(AUC)) %>%
  spread(key = par.h2, value = AUC) %>%
  myggplot(aes(x = `0.5`, y = `0.8`, color = method)) +
  geom_smooth(size = 2, alpha = 0.2) +
  geom_point(size = 3)
## `geom_smooth()` using method = 'loess'
Mean AUC for each combination of parameters and method: h2=0.8 (y-axis) against h2=0.5 (x-axis).
# Correlation between mean AUCs obtained with h2=0.5 and with h2=0.8, over all
# parameter combinations and methods.
results2 %>%
  select(starts_with("par"), method, AUC) %>%
  group_by(par.causal, par.dist, par.model, method, par.h2) %>%
  summarise(AUC = mean(AUC)) %>%
  spread(key = par.h2, value = AUC) %>%
  {
    cor(.$`0.5`, .$`0.8`)
  }
## [1] 0.9875672
# Mean AUC (bars) +/- 2 bootstrap standard errors (error bars) of
# "logit-simple" vs "PRS-max": laplace effects, simple model, h2 = 0.8,
# all chromosomes.
results2 %>%
  filter(par.dist == "laplace", par.h2 == 0.8, par.model == "simple",
         method %in% c("logit-simple", "PRS-max")) %>%
  group_by(par.causal, method) %>%
  summarise(AUC_mean = mean(AUC), AUC_boot = boot(AUC, 1e5, mean)) %>%
  myggplot(aes(par.causal, AUC_mean, fill = method, color = method)) +
  geom_hline(yintercept = 0.5, linetype = 2) +
  geom_bar(stat = "identity", alpha = 0.3, position = position_dodge()) +
  geom_errorbar(aes(ymin = AUC_mean - 2 * AUC_boot,
                    ymax = AUC_mean + 2 * AUC_boot),
                position = position_dodge(width = 0.9),
                color = "black", width = 0.2) +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = 0:10 / 10 + 0.05)
Main results: logit simple vs PRS max
# Mean AUC (bars) +/- 2 bootstrap standard errors (error bars) of every
# method whose name contains "PRS": laplace effects, simple model, h2 = 0.8,
# all chromosomes.
results2 %>%
  filter(par.dist == "laplace", par.h2 == 0.8, par.model == "simple",
         grepl("PRS", method)) %>%
  group_by(par.causal, method) %>%
  summarise(AUC_mean = mean(AUC), AUC_boot = boot(AUC, 1e5, mean)) %>%
  myggplot(aes(par.causal, AUC_mean, fill = method, color = method)) +
  geom_hline(yintercept = 0.5, linetype = 2) +
  geom_bar(stat = "identity", alpha = 0.3, position = position_dodge()) +
  geom_errorbar(aes(ymin = AUC_mean - 2 * AUC_boot,
                    ymax = AUC_mean + 2 * AUC_boot),
                position = position_dodge(width = 0.9),
                color = "black", width = 0.2) +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = 0:10 / 10 + 0.05)
Main results: all PRS
# Put all results in a single tibble (only AUC is derived for this set)
results3 <- list.files("results3", full.names = TRUE) %>%
  map_dfr(readRDS) %>%
  as_tibble() %>%
  mutate(
    par.causal = factor(
      map_chr(par.causal, function(p) paste(p[1], p[2], sep = " in ")),
      levels = c("30 in HLA", paste(3 * 10^(1:3), "in all"))
    ),
    AUC = map_dbl(eval, function(ev) bigstatsr::AUC(ev[, 1], ev[, 2]))
  )
# AUC of all methods for h2 = 0.8 (chromosome 6); dotted blue line at 0.94.
plot_results(filter(results3, par.h2 == 0.8), y = "AUC") +
  geom_hline(yintercept = 0.94, color = "blue", linetype = 3) +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = c(0:9 + 0.5) / 10)
All AUC results for h2=0.8 and chromosome 6
# Relative AUC: each method's AUC divided by the AUC of "logit-simple" in the
# same simulation (same parameters and num.simu), for h2 = 0.8 and
# chromosome 6.
results3 %>%
  filter(par.h2 == 0.8) %>%
  group_by_at(c(vars(starts_with("par")), "num.simu")) %>%
  # match() takes the first "logit-simple" row of the group and gives NA when
  # the group has none, instead of failing on a length-0 or length-2 divisor.
  mutate(AUC_rel = AUC / AUC[match("logit-simple", method)]) %>%
  ungroup() %>%
  plot_results(y = "AUC_rel", ylab = "Relative AUC / 'logit-simple'") +
  # Relative AUCs can exceed 1, so keep labelled breaks above 1 as well.
  scale_y_continuous(breaks = 0:20 / 10, minor_breaks = c(0:19 + 0.5) / 10)
All relative AUC results for h2=0.8 and chromosome 6
# Mean AUC (bars) +/- 2 bootstrap standard errors (error bars) of
# "logit-simple" vs "PRS-max": laplace effects, simple model, h2 = 0.8,
# chromosome 6.
results3 %>%
  filter(par.dist == "laplace", par.h2 == 0.8, par.model == "simple",
         method %in% c("logit-simple", "PRS-max")) %>%
  group_by(par.causal, method) %>%
  summarise(AUC_mean = mean(AUC), AUC_boot = boot(AUC, 1e5, mean)) %>%
  myggplot(aes(par.causal, AUC_mean, fill = method, color = method)) +
  geom_hline(yintercept = 0.5, linetype = 2) +
  geom_bar(stat = "identity", alpha = 0.3, position = position_dodge()) +
  geom_errorbar(aes(ymin = AUC_mean - 2 * AUC_boot,
                    ymax = AUC_mean + 2 * AUC_boot),
                position = position_dodge(width = 0.9),
                color = "black", width = 0.2) +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = 0:10 / 10 + 0.05)
Main results: logit simple vs PRS max for chromosome 6
# Mean AUC (bars) +/- 2 bootstrap standard errors (error bars) of every
# method whose name contains "PRS": laplace effects, simple model, h2 = 0.8,
# chromosome 6.
results3 %>%
  filter(par.dist == "laplace", par.h2 == 0.8, par.model == "simple",
         grepl("PRS", method)) %>%
  group_by(par.causal, method) %>%
  summarise(AUC_mean = mean(AUC), AUC_boot = boot(AUC, 1e5, mean)) %>%
  myggplot(aes(par.causal, AUC_mean, fill = method, color = method)) +
  geom_hline(yintercept = 0.5, linetype = 2) +
  geom_bar(stat = "identity", alpha = 0.3, position = position_dodge()) +
  geom_errorbar(aes(ymin = AUC_mean - 2 * AUC_boot,
                    ymax = AUC_mean + 2 * AUC_boot),
                position = position_dodge(width = 0.9),
                color = "black", width = 0.2) +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = 0:10 / 10 + 0.05)
Main results: all PRS for chromosome 6
knitr::include_graphics(path = "figures/celiac-man.png")  # Manhattan plot
Manhattan plot for Celiac.
knitr::include_graphics(path = "figures/celiac-regpath.png")  # reg. paths
Regularization paths of the methods. For LR, the line marks the result given by CMSA.
# Mean AUC of each method as a function of the number of causal SNPs M, for
# causal SNPs spread over all chromosomes ("... in all") and h2 = 0.8,
# faceted by distribution of effects (rows) and model (columns).
results2 %>%
  filter(grepl("all", par.causal), par.h2 == 0.8) %>%
  # par.causal is a factor; readr::parse_number() requires a character vector.
  mutate(M = readr::parse_number(as.character(par.causal))) %>%
  group_by(M, par.dist, par.h2, par.model, method) %>%
  summarise(AUC = mean(AUC)) %>%
  myggplot(aes(M, AUC, color = method)) +
  geom_line() +
  geom_point(size = 2) +
  facet_grid(par.dist ~ par.model) +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = c(0:9 + 0.5) / 10) +
  scale_x_sqrt(breaks = c(30, 300, 1200, 3000)) +
  labs(x = "Number of causal SNPs (sqrt-scale)")
AUC as function of M. For all chromosomes.
# Mean AUC of each method as a function of the number of causal SNPs M, for
# chromosome 6 simulations with "... in all" causal sets and h2 = 0.8,
# faceted by distribution of effects (rows) and model (columns).
results3 %>%
  filter(grepl("all", par.causal), par.h2 == 0.8) %>%
  # par.causal is a factor; readr::parse_number() requires a character vector.
  mutate(M = readr::parse_number(as.character(par.causal))) %>%
  group_by(M, par.dist, par.h2, par.model, method) %>%
  summarise(AUC = mean(AUC)) %>%
  myggplot(aes(M, AUC, color = method)) +
  geom_line() +
  geom_point(size = 2) +
  facet_grid(par.dist ~ par.model) +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = c(0:9 + 0.5) / 10) +
  scale_x_sqrt(breaks = c(30, 300, 1200, 3000)) +
  labs(x = "Number of causal SNPs (sqrt-scale)")
AUC as function of M. For chromosome 6.
# AUC for h2 = 0.8 and gaussian effects (excluding "logit-triple"), comparing
# simulations on all chromosomes ("all") with chromosome 6 only ("chr6").
bind_rows(all = results2, chr6 = results3, .id = "simu") %>%
  filter(par.h2 == 0.8, par.dist == "gaussian", method != "logit-triple") %>%
  myggplot() +
  geom_boxplot(aes(method, AUC, fill = simu, color = simu), alpha = 0.3) +
  facet_grid(par.model ~ par.causal) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1),
        strip.text.x = element_text(size = rel(2)),
        strip.text.y = element_text(size = rel(2))) +
  scale_colour_brewer(palette = "Set1") +
  scale_fill_brewer(palette = "Set1") +
  geom_hline(yintercept = 0.94, color = "blue", linetype = 2)
AUCs for h2=0.8, dist=gaussian, comparing all chromosomes and chromosome 6.
knitr::include_graphics(path = "effects.png")  # GWAS vs logistic effects
Effect sizes from GWAS vs. logistic regression, for Celiac.
knitr::include_graphics(path = "preds-density2.png")  # density by population
Density of scores from logistic regression, by population.
knitr::include_graphics(path = "preds-density3.png")  # by pop. and genotype
Density of scores from logistic regression, by population and genotype.
knitr::include_graphics(path = "perc_cases2.png")  # percentage of controls
Percentage of controls (errors) in 199 (homozygous bad in test).
knitr::include_graphics(path = "gad-pred.png")  # projection of Gad's score
Projection of Gad's score on the test set, trained on the same dataset.